# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------

# Autogenerated By   : src/main/python/generator/generator.py
# Autogenerated From : scripts/builtin/glove.dml

from typing import Dict, Iterable

from systemds.operator import OperationNode, Matrix, Frame, List, MultiReturn, Scalar
from systemds.utils.consts import VALID_INPUT_TYPES


def glove(input: Frame,
          seed: int,
          vector_size: int,
          alpha: float,
          eta: float,
          x_max: float,
          tol: float,
          iterations: int,
          print_loss_it: int,
          maxTokens: int,
          windowSize: int,
          distanceWeighting: bool,
          symmetric: bool):
    """
    Computes the vector embeddings for words in a large text corpus.

    Every argument is forwarded, under its own name, as a named input of a
    'glove' Matrix node created on the context of ``input``.

    :param input: 1D input corpus in CSV format.
    :param seed: Random seed for reproducibility.
    :param vector_size: Dimensionality of word vectors, V.
    :param alpha: Weighting function parameter, recommended value: 0.75.
    :param eta: Learning rate for optimization, recommended value: 0.05.
    :param x_max: Maximum co-occurrence value as per the GloVe paper: 100.
    :param tol: Tolerance value to avoid overfitting, recommended value: 1e-4.
    :param iterations: Total number of training iterations.
    :param print_loss_it: Interval (in iterations) for printing the loss.
    :param maxTokens: Maximum number of tokens per text entry.
    :param windowSize: Context window size.
    :param distanceWeighting: Whether to apply distance-based weighting.
    :param symmetric: Determines if the matrix is symmetric (TRUE) or asymmetric (FALSE).
    :return: The word indices and their word vectors, of shape (N, V). Each represented as a vector, of shape (1,V)
    """

    # Keys mirror the parameter names one-to-one, in signature order.
    named_inputs = {
        'input': input,
        'seed': seed,
        'vector_size': vector_size,
        'alpha': alpha,
        'eta': eta,
        'x_max': x_max,
        'tol': tol,
        'iterations': iterations,
        'print_loss_it': print_loss_it,
        'maxTokens': maxTokens,
        'windowSize': windowSize,
        'distanceWeighting': distanceWeighting,
        'symmetric': symmetric,
    }

    # The result node is attached to the same context as the input frame.
    context = input.sds_context
    return Matrix(context, 'glove', named_input_nodes=named_inputs)
